A softmax evening

"It's you - The Humans who discriminate, not Us" - said the old wise machine with beaming eyes.

Softmax is one of the most used operations in ML/DL literature. In this post we will write a softmax implementation from scratch. The idea is to write a fast, dependency-free implementation for personal purposes (possibly to be integrated later into a larger library).

The name softmax is a bit misleading, since the operation behaves more like a softer argmax. An argmax routine over K ordered values, i.e. a length-K array, returns the index of the maximum value in that array. If we output 1 for only the maximum value and 0 for the rest, the result is a one-hot encoded vector. Any such vector can be seen as a form of hard encoding.

But in DL literature we prefer nice, smooth, differentiable functions like the exponential, the logarithm or the square. So softmax is meant to be a softer version of the argmax defined above. Rather than generating zeros and a single one, it transforms the original K values into a (smoother) probability distribution, where each transformed value is the weight assigned to the corresponding original value.
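To make the contrast concrete, here is a minimal sketch of the hard encoding next to the soft one. The helper names `one_hot_argmax` and `softmax_naive` are made up for this illustration, and the soft version uses the plain textbook formula with none of the tricks discussed later.

```python
import math
from typing import List

def one_hot_argmax(values: List[float]) -> List[float]:
    """Hard encoding: 1.0 at the index of the maximum, 0.0 elsewhere."""
    best = max(range(len(values)), key=lambda i: values[i])
    return [1.0 if i == best else 0.0 for i in range(len(values))]

def softmax_naive(values: List[float]) -> List[float]:
    """Soft encoding: exp of each value divided by the sum of all exps."""
    exps = [math.exp(v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

values = [2.0, 2.1, 0.5]
hard = one_hot_argmax(values)  # [0.0, 1.0, 0.0]
soft = softmax_naive(values)   # the two close scores keep similar weights
```

The hard encoding forgets that 2.0 was almost as large as 2.1, while the soft encoding keeps that information in the weights.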

It also makes sense to smooth the original distribution. For example, if the original values differ by only a small margin, argmax would not make much sense, as we would be throwing a lot of information away. Functions like max and min lie at the extreme end of the information spectrum, which is one of the reasons max-pooling layers have substantial effects on final architecture performance!

Since we still wish to assign the largest weight to the maximum value and the smallest to the minimum value, the obvious choice is a monotonically increasing function. Another name for this operation is the normalized exponential function.

How:

Current machines are digital machines with much higher precision than older ones, but that precision is limited nonetheless. Assuming we use float32 to store real numbers, we will apply some tricks to generate the probability distribution.

The exponential function has all real numbers as its domain, but for our use case we will restrict its input to non-positive numbers, making sure every output stays in the range (0, 1]. To do that, we subtract the maximum of the values from each original value before taking the exponent. The shift does not change the result, because the common factor exp(-max) cancels between numerator and denominator.
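A small sketch of why the shift matters, using Python's standard `math.exp` as a stand-in for the libC routine. The function names are illustrative only. Python floats are float64, so the overflow threshold is higher than for float32, but the effect is the same.

```python
import math
from typing import List

def softmax_naive(values: List[float]) -> List[float]:
    exps = [math.exp(v) for v in values]  # can overflow for large inputs
    total = sum(exps)
    return [e / total for e in exps]

def softmax_shifted(values: List[float]) -> List[float]:
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]  # every argument is <= 0
    total = sum(exps)                            # at least 1.0, since exp(0) == 1
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
large = [1001.0, 1002.0, 1003.0]  # same differences, shifted by 1000

try:
    softmax_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True  # math.exp raises for arguments this large

shifted_small = softmax_shifted(small)
shifted_large = softmax_shifted(large)  # works, and matches shifted_small
```

Because only the differences between the values matter, both inputs give the same distribution once the maximum is subtracted.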

But the exp routine provided by libC is not fast enough, so using it for a large number of values would introduce latency we may not like! It is not that the original exp implementation does unnecessary work; rather, it is written to support any real value as its input. Since we now know that our input to exp will always be <= 0, we can leverage this domain-specific knowledge to write a simpler and faster implementation of exp.

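def expSimple(x: float) -> float:
    # Less precise approximation of exp(x), intended for x <= 0.
    # Uses exp(x) ~ (1 + x/1024)^1024; most accurate for small |x| (roughly |x| < 3).

    x = max(x, -1024)  # clamp so that (1 + x/1024) never goes negative
    result = (1 + (x / 1024))
    # Square 10 times: result^(2^10) == result^1024.
    result = result * result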
    result = result * result
    result = result * result
    result = result * result
    result = result * result
    result = result * result
    result = result * result
    result = result * result
    result = result * result
    result = result * result
    return result

[Figure: comparison]

The above implementation is nothing special. Almost all routines, including exp, sin and cos, come from some series expansion or limit and can therefore be implemented using only additions and multiplications to reach a desired precision. We can think of the above implementation as a function that closely approximates the exponential function, at least for negative values, but is much faster to compute. It is also much easier to vectorize, which gives a further speedup on systems with such hardware.
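One way to get a feel for the approximation is to compare it against a reference exp on a grid of non-positive inputs. The sketch below re-declares the routine as `exp_simple` (with a loop instead of ten explicit squarings) and uses `math.exp` as the reference. The grid and the helper `max_abs_error` are assumptions of this example, and Python floats are float64, so the numbers will not match a float32 implementation exactly.

```python
import math
from typing import Iterable

def exp_simple(x: float) -> float:
    """Approximate exp(x) for x <= 0 as (1 + x/1024)^1024."""
    x = max(x, -1024.0)
    result = 1.0 + x / 1024.0
    for _ in range(10):  # ten squarings give the 1024th power
        result *= result
    return result

def max_abs_error(points: Iterable[float]) -> float:
    """Largest absolute gap between exp_simple and math.exp over the points."""
    return max(abs(exp_simple(x) - math.exp(x)) for x in points)

grid = [-i / 100.0 for i in range(0, 1001)]  # 0.0 down to -10.0 in steps of 0.01
worst = max_abs_error(grid)
print(f"worst absolute error on the grid: {worst:.3e}")
```

Since softmax divides by the sum of the same approximated values, part of this error cancels in the final ratios, so the error of the exp routine alone is not the error of the softmax output.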

For the final transformation we need the max of the original array before we can take the exp of the values, and afterwards we also need the total sum of the transformed values.

def transform(original: float, max: float, sumexp: float) -> float:
    # One output element: exp of the shifted value, normalized by the total.
    return expSimple(original - max) / sumexp

We can calculate the max by scanning each value in the array/list and keeping track of the maximum value encountered so far.

from typing import List, Tuple

def get_max(data: List[float]) -> float:
    prev_max = float("-inf")  # smaller than any finite value.
    for d in data:
        if d > prev_max:
            prev_max = d
    return prev_max

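The thing to note here is that at any point during the array traversal we have access to the largest number encountered so far. Given the final max, get_sumexp below computes the sum of exponentials in a second pass, while accumulate_sumexp shows how a single new element can update that sum when only the max so far is known.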

from math import exp  # reference exp; expSimple can be substituted

def get_sumexp(data: List[float], max: float) -> float:
    sumexp = 0.0
    for d in data:
        sumexp += exp(d - max)
    return sumexp

def accumulate_sumexp(current_data: float, sumexp: float, prev_max: float) -> Tuple[float, float]:
    if current_data <= prev_max:
        # Max is unchanged: just add the new term.
        # new_sumexp = sumexp + exp(current_data - prev_max)
        sumexp += exp(current_data - prev_max)
    else:
        # New max d = current_data: every old term must be rescaled.
        # exp(data[i] - d) = exp(data[i] - prev_max) * exp(prev_max - d)
        # new_sumexp = sumexp * exp(prev_max - d) + exp(d - d)
        #            = sumexp * exp(prev_max - d) + 1
        sumexp = sumexp * exp(prev_max - current_data) + 1
        prev_max = current_data
    return (prev_max, sumexp)

Note that to update sumexp at each step of the scan, we only need the current value, the sumexp accumulated so far, and the largest number encountered so far. The current value and prev_max are already available inside the get_max loop. These observations lead to the following routine.

# Also known as streaming softmax implementation.
# We always pass a value <= 0 to exp, so our expSimple assumption remains valid.
def get_max_sumexp(data: List[float]) -> Tuple[float, float]:
    prev_max = float("-inf")  # smaller than any finite value.
    sumexp = 0.0
    for d in data:
        if d > prev_max:
            # New max: rescale the old sum, then add exp(d - d) == 1.
            sumexp = sumexp * exp(prev_max - d) + 1
            prev_max = d
        else:
            sumexp += exp(d - prev_max)
    return (prev_max, sumexp)

We now access each element only once and update the max and sumexp in a streaming manner. This reduces costly memory loads, as each element of the data array is read once rather than twice. If the data already resides at contiguous locations in memory, very few actual memory loads should be needed, as most requests would be served by the data cache.
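As a sanity check, the single-pass update can be compared against the plain two-pass computation. The sketch below is self-contained and uses `math.exp`. The names `two_pass` and `streaming` and the sample data are chosen for this example only.

```python
import math
from typing import List, Tuple

def two_pass(data: List[float]) -> Tuple[float, float]:
    """Reference: find the max first, then sum the shifted exponentials."""
    peak = max(data)
    return peak, sum(math.exp(d - peak) for d in data)

def streaming(data: List[float]) -> Tuple[float, float]:
    """Single pass: rescale the running sum whenever a new max appears."""
    running_max = float("-inf")
    sumexp = 0.0
    for d in data:
        if d > running_max:
            # Old terms were relative to running_max; re-express them relative to d.
            sumexp = sumexp * math.exp(running_max - d) + 1.0
            running_max = d
        else:
            sumexp += math.exp(d - running_max)
    return running_max, sumexp

sample = [0.5, -1.0, 3.0, 2.5, 3.0, 7.25, -4.0]
ref_max, ref_sum = two_pass(sample)
run_max, run_sum = streaming(sample)
```

On the very first element the running sum is 0, so the rescaling term vanishes and the sum starts at exactly 1, which is why initializing the max to negative infinity is safe for finite inputs.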

The exact speedup depends on the cache hierarchy and the actual data layout. In my experiments with a packed (512 x 512) float32 tensor, computing softmax along the last dimension, the streaming softmax ran about 3% faster than the non-streaming one (keeping the exp implementation the same).

Speed gains from the streaming implementation would be much more visible on systems without a cache than on ones with one or more caches.

Some algorithms were originally written for older systems where every extra instruction had a significant cost, a cost that may not be as visible on current faster but more complex pipelines. So always measure, even when you think you know!

If, for a tensor with more than one dimension, the dimension along which softmax is calculated is contiguous, i.e. has stride 1, that works in our favor, as it lets us use the cache more efficiently.

[Figure: cache-ram]

Also note that modern computing pipelines are quite complex, and the actual data movement from RAM to registers (and within them) depends on many factors, such as the number of caches, their policies, and the actual data-access patterns. The idea should be to divide an algorithm into submodules/subroutines while focusing on data reuse and minimizing memory stores and loads. For me the most important factor is always to use the cache more efficiently; other optimizations (multi-threading, vectorization) are much easier to implement and understand.

Of course we may need to calculate softmax along some non-contiguous dimension too, and then we may just have to accept less efficient cache use. But as I stressed, we must always measure and not just assume. We can reorder the data so that the desired dimension becomes contiguous, but only benchmarking will tell us whether the extra computation was worth it.
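To illustrate what contiguous versus non-contiguous means here, the sketch below stores a small matrix in a flat, row-major list (an assumed layout) and computes softmax either along rows (stride 1) or along columns (stride equal to the number of columns). The helper `softmax_strided` is hypothetical and only shows the access pattern. Pure Python will not show the cache effects themselves.

```python
import math
from typing import List

def softmax_strided(buf: List[float], start: int, count: int, stride: int) -> None:
    """In-place softmax over buf[start], buf[start + stride], ... (count items)."""
    idx = range(start, start + count * stride, stride)
    peak = max(buf[i] for i in idx)
    total = 0.0
    for i in idx:
        buf[i] = math.exp(buf[i] - peak)
        total += buf[i]
    for i in idx:
        buf[i] /= total

rows, cols = 2, 3
flat = [1.0, 2.0, 3.0,   # row 0
        1.0, 1.0, 1.0]   # row 1

# Along the last dimension: consecutive addresses, stride 1.
by_row = list(flat)
for r in range(rows):
    softmax_strided(by_row, r * cols, cols, 1)

# Along the first dimension: each step jumps over a whole row, stride == cols.
by_col = list(flat)
for c in range(cols):
    softmax_strided(by_col, c, rows, cols)
```

With stride 1 each cache line brought in from memory is fully used, while with a large stride only one element per line may be touched before it is evicted, which is why the contiguous case tends to be friendlier to the cache.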

One more optimization we can do is multi-threading, i.e. running more than one execution flow fully in parallel when the OS supports it. Modern systems ship with more than one computing core, allowing users to take advantage of parallel execution at the cost of extra work.

A very simple form of multi-threading is one where the code doesn't need any lock or sync barrier, so each execution flow runs to completion independently. Some problems lend themselves easily to this kind of pattern.

For example, in our case of a 2D tensor, we may need to calculate softmax along the last dimension. For this operation no row of the tensor depends on any other row, so softmax can be calculated for each row independently. A full algorithm may not lend itself to this lockless pattern, but in many cases we can refactor it to expose such a pattern. Executing submodules in this way is also known as the fork-join paradigm. There is a lot more theory on using multi-threading efficiently, and it should be studied if one wishes to use it frequently.

Using multi-threading may involve writing extra code if you are using a minimalistic API, but once written such code is very portable, especially if the systems use the same threading standard, such as POSIX. The fork-join paradigm also makes it easy to set the number of threads at runtime, to experiment depending on the number of independent cores on a system. Sure, using threads in this way adds some overhead, but as the work per thread increases this overhead gets amortized!
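The fork-join pattern over independent rows can be sketched with Python's standard `concurrent.futures` module. This is only an illustration of the structure: in CPython the global interpreter lock keeps pure-Python threads from running this arithmetic in parallel, so it will not reproduce the speedups of a native (for example Nim or POSIX threads) implementation. The function names and the `num_threads` parameter are assumptions of this example.

```python
import math
from concurrent.futures import ThreadPoolExecutor
from typing import List

def softmax_row(row: List[float]) -> List[float]:
    """Softmax of one row; touches no shared state, so no locks are needed."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_rows_parallel(matrix: List[List[float]], num_threads: int = 4) -> List[List[float]]:
    """Fork: hand rows to a pool of workers. Join: collect results in row order."""
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return list(pool.map(softmax_row, matrix))

matrix = [[float(r + c) for c in range(8)] for r in range(16)]
sequential = [softmax_row(row) for row in matrix]
parallel = softmax_rows_parallel(matrix, num_threads=2)
```

The thread count is just an argument, so it can be varied at runtime to match the number of cores, and each worker writes only its own output row.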

Vectorization:

Vectorization allows us to use the SIMD (single instruction, multiple data) units in hardware to process more than one unit of data simultaneously. But such units need to be kept fed by exploiting instruction-level parallelism, as these instructions generally have higher latency than the corresponding SISD (single instruction, single data) instructions. It may not always be possible to refactor code to expose instruction-level parallelism, so we may not see huge benefits just from using vector instructions, but we do generally get some speedup in most cases.

In this case I chose to calculate the max first, as that leads to an easier vectorized implementation, with no branching in the exp(data - max) computation.

# final softmax representation (pseudo code)
max = get_max(data)  # generally very fast for a few K elements.

sumexp = get_sumexp_vectorized(data, max)  # exp(data[0] - max) + exp(data[1] - max) + ....

for i in range(len(data)):
    data[i] = exp(data[i] - max) / sumexp

Results:

I actually implemented everything in pure Nim and ran it on my quad-core Intel i5 (8th generation) machine. PyTorch version 2.0.0 with torch.softmax() was used for comparison.

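| Config                    | Matrix (512 x 512) float32 | Matrix (1024 x 1024) float32 |
| ------------------------- | -------------------------- | ---------------------------- |
| Threads 8 + vectorization | 0.54 ms                    | 1.3 ms                       |
| Threads 4 + vectorization | 0.53 ms                    | 1.6 ms                       |
| Threads 2 + vectorization | 0.73 ms                    | 2.5 ms                       |
| PyTorch                   | 1.1 ms                     | 3.7 ms                       |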

Using our own expSimple implementation in the final softmax produces tensors whose maximum absolute difference from the expected output is less than 1e-7. So all in all we got a good deal!

